import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Suffix of the input CSV file name (presumably a run / sample-size tag — TODO confirm).
c = 1000

# Format of the ``date`` column in the input CSV.
_DATE_FORMAT = '%Y-%m-%d'


def dateparse(x):
    """Parse a ``YYYY-MM-DD`` string into a ``datetime``.

    Kept so the module-level name still exists; the loader below uses the
    vectorised ``pd.to_datetime`` with the same format instead.
    """
    return datetime.strptime(x, _DATE_FORMAT)


# Load the labelled narratives. ``read_csv(date_parser=...)`` is deprecated
# since pandas 2.0, so the date column is converted explicitly after loading;
# this works on both older and newer pandas versions.
main_data = pd.read_csv(
    '../data/gpo_final_data/narratives_complete_with_metadata_manual_labels_rich_{0}.csv'.format(c)
)
main_data['date'] = pd.to_datetime(main_data['date'], format=_DATE_FORMAT)

# Add year field
main_data['year'] = main_data['date'].dt.year

# Replace zero and +/-inf odds ratios with NaN so they are skipped by the
# per-narrative means. Assign the result back rather than calling
# ``replace(..., inplace=True)`` on the selected column: that chained in-place
# call does not update ``main_data`` under pandas Copy-on-Write.
main_data['or'] = main_data['or'].replace([np.inf, -np.inf, 0], np.nan)

#########################################################################
# Plot narratives with highest, lowest, and most centered Odds Ratio

# Number of narratives taken from each end of the odds-ratio ranking and from
# around zero log odds.
n_per_group = 10

# Per-narrative means, computed once and reused by all three selections. Only
# the odds-ratio columns are averaged: calling .mean() on the whole grouped
# frame also tries to average the other columns, which raises on
# pandas >= 2.0 if any of them is non-numeric.
_narrative_means = main_data.groupby('narrative')[['or', 'log_or']].mean()

# Highest mean odds ratio (the "Republican" group), listed in ascending order.
rep_narratives = list(
    _narrative_means['or'].nlargest(n_per_group).sort_values(ascending=True).index
)
# Lowest mean odds ratio (the "Democratic" group), lowest first.
dem_narratives = list(_narrative_means['or'].nsmallest(n_per_group).index)
# Mean log odds ratio closest to zero (the "neutral" group), closest first.
neutral_narratives = list(
    _narrative_means['log_or'].abs().nsmallest(n_per_group).index
)

narratives_to_plot = dem_narratives + neutral_narratives + rep_narratives

# Filter df for narratives to plot
_cleaned_main_data = main_data[main_data['narrative'].isin(narratives_to_plot)]

# Keep one row per narrative (the first one encountered).
# NOTE(review): the plotted values come from this first row, while the
# selection above used per-narrative means; these agree only if the odds-ratio
# columns are constant within a narrative — confirm.
_unique_main_data = _cleaned_main_data.drop_duplicates(subset=['narrative'])

# Re-arrange rows to follow narratives_to_plot. DataFrame.append was removed in
# pandas 2.0 (and appending inside a loop is quadratic), so collect the
# per-narrative slices and concatenate them in one call.
_plot_columns = ['narrative', 'or', 'log_or', 'log_or_ci_lower', 'log_or_ci_upper']
_ordered_rows = [
    _unique_main_data.loc[_unique_main_data['narrative'] == narrative, _plot_columns]
    for narrative in narratives_to_plot
]
# pd.concat rejects an empty list, so fall back to an empty frame.
_unique_main_data_sorted = (
    pd.concat(_ordered_rows, ignore_index=True)
    if _ordered_rows
    else pd.DataFrame(columns=_plot_columns)
)

# Plotting
_fig, _ax = plt.subplots(figsize=(47, 25))
_ax.xaxis.label.set_fontsize(40)
_ax.set_xlabel('Log Odds Ratio')
_ax.set_ylabel('Narrative', visible=False)
_ax.yaxis.label.set_fontsize(40)
_ax.set_facecolor((.18, .31, .31, 0.0))  # alpha 0: transparent background
_ax.grid(color='#dfede2')
_ax.set_axisbelow(True)
_ax.set_aspect('equal', 'box')

# seaborn drops rows with missing (and possibly infinite) x/y values before
# drawing. If such a row were kept here, the scatter offsets read back below
# would be shorter than the error arrays and errorbar() would fail. Plot only
# rows with a finite log odds ratio. The columns are coerced to numeric first
# because the assembled frame may carry object dtype.
_err_columns = ['log_or', 'log_or_ci_lower', 'log_or_ci_upper']
_plot_data = _unique_main_data_sorted.copy()
_plot_data[_err_columns] = _plot_data[_err_columns].apply(pd.to_numeric, errors='coerce')
_plot_data = _plot_data[np.isfinite(_plot_data['log_or'])]

_g = sns.scatterplot(data=_plot_data,
                     x='log_or',
                     y='narrative',
                     s=500,
                     marker='D',
                     color='#598ee3',
                     alpha=1.0,
                     ax=_ax)

# Read back the drawn (x, y) positions. y is categorical, so these are the
# numeric positions seaborn assigned to each narrative, in row order.
x_coords = []
y_coords = []
for point_pair in _g.collections:
    for x, y in point_pair.get_offsets():
        x_coords.append(x)
        y_coords.append(y)

# Asymmetric horizontal error bars: distance from log_or to the lower and upper
# CI columns.
_ax.errorbar(x=x_coords, y=y_coords,
             xerr=[(_plot_data['log_or_ci_lower'] - _plot_data['log_or']).abs().to_numpy(),
                   (_plot_data['log_or_ci_upper'] - _plot_data['log_or']).abs().to_numpy()],
             fmt='o',
             ecolor='#598ee3')

# Vertical reference line at zero log odds (odds ratio of 1).
_ax.axvline(x=0.0, linestyle='--')

# Style ticks and save through this figure's own Axes/Figure objects rather
# than pyplot's implicit "current" ones.
_ax.tick_params(axis='both', labelsize=40)
_fig.savefig('../figures/Figure_H_1.pdf')
# Release the (very large) figure once it has been written.
plt.close(_fig)
